{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导包\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import load_iris  #鸢尾花数据集\n",
    "from sklearn.model_selection import train_test_split,GridSearchCV,cross_val_score\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.preprocessing import StandardScaler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入数据\n",
    "X = load_iris().data\n",
    "y = load_iris().target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.DataFrame(X,columns=load_iris().feature_names)\n",
    "data['label'] = y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sepal length (cm)</th>\n",
       "      <th>sepal width (cm)</th>\n",
       "      <th>petal length (cm)</th>\n",
       "      <th>petal width (cm)</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5.1</td>\n",
       "      <td>3.5</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.7</td>\n",
       "      <td>3.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.6</td>\n",
       "      <td>3.1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.0</td>\n",
       "      <td>3.6</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  \\\n",
       "0                5.1               3.5                1.4               0.2   \n",
       "1                4.9               3.0                1.4               0.2   \n",
       "2                4.7               3.2                1.3               0.2   \n",
       "3                4.6               3.1                1.5               0.2   \n",
       "4                5.0               3.6                1.4               0.2   \n",
       "\n",
       "   label  \n",
       "0      0  \n",
       "1      0  \n",
       "2      0  \n",
       "3      0  \n",
       "4      0  "
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(150, 5)"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 切分数据集\n",
    "X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "StandardScaler()"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 使用标准包，对训练集来学习，从而对训练集和测试集做标准化\n",
    "std = StandardScaler()\n",
    "std.fit(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([5.88380952, 3.04952381, 3.84761905, 1.2247619 ]),\n",
       " array([0.80700785, 0.4392903 , 1.70953044, 0.73714778]))"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "std.mean_,std.scale_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_stand = std.transform(X_train)\n",
    "X_test_stand = std.transform(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9714285714285714"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf = LogisticRegression()\n",
    "clf = clf.fit(X_train_stand,y_train)\n",
    "clf.score(X_train_stand,y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E:\\Program Files\\Anaconda3\\lib\\site-packages\\sklearn\\model_selection\\_validation.py:372: FitFailedWarning: \n",
      "95 fits failed out of a total of 380.\n",
      "The score on these train-test partitions for these parameters will be set to nan.\n",
      "If these failures are not expected, you can try to debug them by setting error_score='raise'.\n",
      "\n",
      "Below are more details about the failures:\n",
      "--------------------------------------------------------------------------------\n",
      "95 fits failed with the following error:\n",
      "Traceback (most recent call last):\n",
      "  File \"E:\\Program Files\\Anaconda3\\lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 680, in _fit_and_score\n",
      "    estimator.fit(X_train, y_train, **fit_params)\n",
      "  File \"E:\\Program Files\\Anaconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py\", line 1461, in fit\n",
      "    solver = _check_solver(self.solver, self.penalty, self.dual)\n",
      "  File \"E:\\Program Files\\Anaconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py\", line 436, in _check_solver\n",
      "    % (all_solvers, solver)\n",
      "ValueError: Logistic Regression supports only solvers in ['liblinear', 'newton-cg', 'lbfgs', 'sag', 'saga'], got Ibfgs.\n",
      "\n",
      "  warnings.warn(some_fits_failed_message, FitFailedWarning)\n",
      "E:\\Program Files\\Anaconda3\\lib\\site-packages\\sklearn\\model_selection\\_search.py:972: UserWarning: One or more of the test scores are non-finite: [0.86666667        nan 0.82857143 0.86666667 0.9047619         nan\n",
      " 0.83809524 0.9047619  0.92380952        nan 0.85714286 0.92380952\n",
      " 0.93333333        nan 0.84761905 0.93333333 0.93333333        nan\n",
      " 0.86666667 0.93333333 0.94285714        nan 0.86666667 0.94285714\n",
      " 0.94285714        nan 0.86666667 0.94285714 0.94285714        nan\n",
      " 0.86666667 0.94285714 0.95238095        nan 0.86666667 0.95238095\n",
      " 0.95238095        nan 0.87619048 0.95238095 0.95238095        nan\n",
      " 0.87619048 0.95238095 0.95238095        nan 0.87619048 0.95238095\n",
      " 0.95238095        nan 0.87619048 0.95238095 0.95238095        nan\n",
      " 0.87619048 0.95238095 0.95238095        nan 0.87619048 0.95238095\n",
      " 0.95238095        nan 0.88571429 0.95238095 0.96190476        nan\n",
      " 0.88571429 0.96190476 0.96190476        nan 0.88571429 0.96190476\n",
      " 0.96190476        nan 0.88571429 0.96190476]\n",
      "  category=UserWarning,\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.9619047619047618"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 在确定l2范式的情况下，使用网格搜索判断solver, C的最优组合\n",
    "p = {\n",
    "    'C':list(np.linspace(0.05,1,19)),\n",
    "    'solver':['newton-cg','Ibfgs','liblinear','sag']\n",
    "\n",
    "}\n",
    "model = LogisticRegression(penalty='l2')\n",
    "GS = GridSearchCV(model,p,cv=5)\n",
    "GS.fit(X_train_stand,y_train)\n",
    "GS.best_score_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'C': 0.8944444444444445, 'solver': 'newton-cg'}"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "GS.best_params_  #最优参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['sepal length (cm)',\n",
       " 'sepal width (cm)',\n",
       " 'petal length (cm)',\n",
       " 'petal width (cm)']"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 将最优的结果重新用来实例化模型，查看训练集和测试集下的分数\n",
    "(注意多分类需要增加参数  average='micro')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9714285714285714"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf_best = LogisticRegression(penalty='l2',\n",
    "                  C=0.8944444444444445,\n",
    "                  solver='newton-cg')\n",
    "clf_best.fit(X_train_stand,y_train)\n",
    "clf_best.score(X_train_stand,y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf_best.fit(X_test_stand,y_test)\n",
    "clf_best.score(X_test_stand,y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
